# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import numpy as np
from numpy import random as rand


def _fw_design_inverse(outers, lambda_vec):
    """Return the (pseudo-)inverse of the design matrix sum_i lambda_i * outers[i].

    The pseudo-inverse is used only when the determinant is exactly zero,
    e.g. when the allocation puts mass on too few arms to span the space.
    """
    # tensordot contracts the arm axis without materialising the weighted
    # (n_arms, p, p) copy of ``outers``.
    A_lambda = np.tensordot(lambda_vec, outers, axes=1)
    if np.linalg.det(A_lambda) == 0:
        return np.linalg.pinv(A_lambda)
    return np.linalg.inv(A_lambda)


def _fw_predictive_variances(Y, cov_A):
    """Return y_j^T cov_A y_j for every row y_j of Y.

    Computed row-wise, which avoids building the (m, m) matrix
    ``Y @ cov_A @ Y.T`` only to read its diagonal.
    """
    return ((Y @ cov_A) * Y).sum(-1)


def FrankWolfe_XY(outers, X, Y, iterations, threshold, warm_start):
    """Frank-Wolfe iterations that reduce the maximum predictive variance over Y.

    The design matrix is A(lambda) = sum_i lambda_i * outers[i] (``G_optimal``
    passes outers[i] = x_i x_i^T). Each step moves lambda towards the arm
    with the most negative derivative of y_max^T A(lambda)^{-1} y_max, where
    y_max is the row of Y with the largest predictive variance.

    Args:
        outers: stack of per-arm matrices, one per row of X.
        X: arms, one per row.
        Y: directions whose predictive variance is reduced, one per row.
        iterations: the loop runs for k = 1, ..., iterations - 1.
        threshold: stop once the relative change of the maximum predictive
            variance between consecutive iterations falls below this value.
        warm_start: initial allocation, one weight per arm. It is copied and
            never modified.

    Returns:
        (cov_A, y_max_val, lambda_vec): the (pseudo-)inverse design matrix and
        the maximum predictive variance computed at the start of the last
        iteration (before that iteration's update of lambda), and the final
        allocation. With ``iterations <= 1`` no step is taken and the values
        for ``warm_start`` are returned.
    """
    # Work on a float copy so the caller's warm start is never modified in
    # place (and integer warm starts do not break the in-place update).
    lambda_vec = np.array(warm_start, dtype=float)
    old_y_max_val = 1
    cov_A = None
    y_max_val = None
    for k in range(1, iterations):
        cov_A = _fw_design_inverse(outers, lambda_vec)

        # Direction with maximal predictive uncertainty.
        variances = _fw_predictive_variances(Y, cov_A)
        max_index = np.argmax(variances)
        y_max = Y[max_index]
        y_max_val = variances[max_index]

        # d/d lambda_i of y^T A(lambda)^{-1} y is -(y^T A^{-1} x_i)^2.
        lambda_derivative = -(y_max @ cov_A @ X.T) ** 2

        # Frank-Wolfe step towards the vertex with the most negative derivative.
        alpha = 2 / (k + 2)
        min_lambda_derivative_index = np.argmin(lambda_derivative)
        lambda_vec -= alpha * lambda_vec
        lambda_vec[min_lambda_derivative_index] += alpha

        # Stop on zero variance or a small relative change of the maximum.
        if y_max_val == 0 or abs((old_y_max_val - y_max_val) / old_y_max_val) < threshold:
            break
        old_y_max_val = y_max_val

    if cov_A is None:
        # The loop body never ran (iterations <= 1): report the warm start.
        cov_A = _fw_design_inverse(outers, lambda_vec)
        y_max_val = np.max(_fw_predictive_variances(Y, cov_A))
    return cov_A, y_max_val, lambda_vec

def G_optimal(X, threshold, iterations):
    """Compute a G-optimal design over the rows of X with Frank-Wolfe.

    Starts from the uniform allocation over all arms and delegates to
    ``FrankWolfe_XY`` with Y = X. Returns its ``(cov_A, y_max_val,
    lambda_vec)`` tuple unchanged.
    """
    n_arms = X.shape[0]
    uniform_start = np.full(n_arms, 1 / n_arms)
    # Rank-one matrices x_i x_i^T stacked along the first axis.
    outer_products = X[:, :, np.newaxis] * X[:, np.newaxis, :]
    return FrankWolfe_XY(outer_products, X, X, iterations, threshold, uniform_start)

def make_Cov_space(X):
    """Map each arm x to its covariance feature vector phi(x).

    phi(x) stacks 2 * x_i * x_j for the strictly upper-triangular pairs
    (i < j, in ``np.triu_indices`` order) followed by the squares x_i^2.
    Then phi(x) . vech(Sigma) = x^T Sigma x when vech(Sigma) lists the upper
    off-diagonal entries of Sigma followed by its diagonal.

    Args:
        X: (n_arms, d) array of arms.

    Returns:
        float64 array of shape (n_arms, d * (d + 1) // 2).
    """
    d = X.shape[1]
    rows, cols = np.triu_indices(d, 1)
    # Vectorised over all arms instead of a Python loop building a d x d
    # outer product per arm.
    cross_terms = np.multiply(2, X[:, rows] * X[:, cols])
    squares = X * X
    return np.concatenate([cross_terms, squares], axis=1).astype(np.float64)

def Sample_X(X, Phi, threshold, iterations, samples):
    """Turn G-optimal designs on X and on Phi into concrete arm pulls.

    Each design's weights are normalised and scaled by ``samples``. Arm i is
    then pulled ``ceil(samples * w_i)`` times, so the total number of pulls
    can exceed ``samples``.

    Returns:
        (X_theta, X_Sigma, Sigma_choice, theta_choice): the arms pulled for
        the theta design (on X), the arms pulled for the Sigma design (on
        Phi), and the corresponding lists of arm indices (Python ints).
    """
    total_arms = X.shape[0]
    _, _, lambda_vec_phi = G_optimal(Phi, threshold, iterations)
    _, _, lambda_vec = G_optimal(X, threshold, iterations)

    def _expand(weights):
        # Arm i repeated ceil(samples * w_i) times. np.repeat is linear in the
        # output size, unlike concatenating lists with sum(..., []).
        counts = np.ceil(samples * (weights / np.sum(weights))).astype(int)
        return np.repeat(np.arange(total_arms), np.maximum(counts, 0)).tolist()

    theta_choice = _expand(lambda_vec)
    Sigma_choice = _expand(lambda_vec_phi)
    X_theta = X[theta_choice]
    X_Sigma = X[Sigma_choice]
    return X_theta, X_Sigma, Sigma_choice, theta_choice


def two_spaces(X, Sigma, theta, threshold, iterations, samples):
    """Sample-split estimator of theta and Sigma using two G-optimal designs.

    ``samples // 2`` is handed to ``Sample_X``, which returns separate pulls
    for a G-optimal design on X and for one on the covariance features Phi.
    The X-design pulls fit theta by ordinary least squares. The Phi-design
    pulls are only used to regress squared residuals (w.r.t. that theta
    estimate) onto Phi, which gives vech(Sigma). If the Phi Gram matrix has
    a zero determinant, a unit ridge is added before inverting.

    Returns:
        (theta_hat_G, Sigma_hat_G). Sigma_hat_G is symmetric but is NOT
        projected onto the positive semidefinite cone.
    """
    dim = X.shape[1]
    Phi = make_Cov_space(X)
    X_theta, X_Sigma, Sigma_choice, _ = Sample_X(X, Phi, threshold, iterations, samples // 2)
    Phi_Sigma = Phi[Sigma_choice]

    # Stage 1: pull the theta-design arms; each pull has noise variance x^T Sigma x.
    theta_scale = np.sqrt((X_theta @ Sigma * X_theta).sum(-1))
    theta_noise = np.multiply(theta_scale, rand.randn(X_theta.shape[0]))
    Y_theta = X_theta @ theta + theta_noise
    gram_inv_theta = np.linalg.inv(X_theta.T @ X_theta)
    theta_hat_G = gram_inv_theta @ X_theta.T @ Y_theta

    # Stage 2: fresh pulls on the Sigma-design arms, squared residuals vs theta_hat_G.
    sigma_scale = np.sqrt((X_Sigma @ Sigma * X_Sigma).sum(-1))
    sigma_noise = np.multiply(sigma_scale, rand.randn(X_Sigma.shape[0]))
    Y_Sigma = X_Sigma @ theta + sigma_noise
    squared_resid = np.square(Y_Sigma - X_Sigma @ theta_hat_G)

    gram = Phi_Sigma.T @ Phi_Sigma
    if np.linalg.det(gram) == 0:
        gram = gram + np.diag([1] * Phi_Sigma.shape[1])
    vech_hat = np.linalg.inv(gram) @ Phi_Sigma.T @ squared_resid

    # vech layout: strictly-upper entries (np.triu_indices order), then the diagonal.
    upper = np.zeros((dim, dim))
    upper[np.triu_indices(dim, 1)] = vech_hat[:-dim]
    Sigma_hat_G = upper + upper.T
    np.fill_diagonal(Sigma_hat_G, vech_hat[-dim:])
    return theta_hat_G, Sigma_hat_G



def White_Estimator(X, Sigma, theta, threshold, iterations, samples):
    """Estimate theta and Sigma from ``samples`` uniformly random arm pulls.

    theta is fit by ordinary least squares. Sigma is fit by regressing the
    squared residuals onto the covariance features ``make_Cov_space(X)``.
    The same pulls are used for both (no sample split). ``threshold`` and
    ``iterations`` are accepted for signature compatibility and are unused.

    Returns:
        (theta_hat_G, Sigma_hat_G). Sigma_hat_G is symmetric but is not
        projected onto the positive semidefinite cone.
    """
    total_arms = X.shape[0]
    d = X.shape[1]
    Phi = make_Cov_space(X)
    choice = rand.choice(total_arms, size=(1, samples))[0]
    Phi_theta = Phi[choice]
    X_theta = X[choice]

    # Pull arms: noise variance x^T Sigma x per pull, computed row-wise rather
    # than via np.diag(X_theta @ Sigma @ X_theta.T), which would allocate a
    # (samples, samples) matrix.
    X_theta_noise = (X_theta @ Sigma * X_theta).sum(-1)
    Y_theta = X_theta @ theta + np.multiply(np.sqrt(X_theta_noise), rand.randn(samples))
    theta_hat_G = np.linalg.inv(X_theta.T @ X_theta) @ X_theta.T @ Y_theta

    # Regress squared residuals onto phi(x) to estimate vech(Sigma).
    Sigma_SE = np.square(Y_theta - X_theta @ theta_hat_G)
    vech_Sigma_hat_G = np.linalg.inv(Phi_theta.T @ Phi_theta) @ Phi_theta.T @ Sigma_SE

    # vech layout: strictly-upper entries (np.triu_indices order), then the diagonal.
    upper_Sigma_hat_G = np.zeros((d, d))
    upper_Sigma_hat_G[np.triu_indices(d, 1)] = vech_Sigma_hat_G[:-d]
    Sigma_hat_G = upper_Sigma_hat_G + upper_Sigma_hat_G.T
    np.fill_diagonal(Sigma_hat_G, vech_Sigma_hat_G[-d:])
    return theta_hat_G, Sigma_hat_G

def two_spaces_no_split(X, Sigma, theta, threshold, iterations, samples):
    """Design-based estimator of theta and Sigma without sample splitting.

    ``Sample_X`` (with ``samples // 2``) supplies the pulls for the G-optimal
    design on X and for the one on Phi. These are pooled, and the same pulls
    are used to fit theta by least squares and Sigma by regressing squared
    residuals onto the covariance features. The Sigma estimate is projected
    onto the PSD cone by clipping negative eigenvalues to zero.

    Returns:
        (theta_hat_G, Sigma_hat_G).
    """
    d = X.shape[1]
    Phi = make_Cov_space(X)
    _, _, Sigma_choice, theta_choice = Sample_X(X, Phi, threshold, iterations, samples // 2)
    choice = list(Sigma_choice) + list(theta_choice)
    Phi_theta = Phi[choice]
    X_theta = X[choice]
    samples = len(choice)

    # Pull arms: noise variance x^T Sigma x per pull, computed row-wise rather
    # than via np.diag(X_theta @ Sigma @ X_theta.T), which would allocate a
    # (samples, samples) matrix.
    X_theta_noise = (X_theta @ Sigma * X_theta).sum(-1)
    Y_theta = X_theta @ theta + np.multiply(np.sqrt(X_theta_noise), rand.randn(samples))
    theta_hat_G = np.linalg.inv(X_theta.T @ X_theta) @ X_theta.T @ Y_theta

    # Regress squared residuals onto phi(x) to estimate vech(Sigma).
    Sigma_SE = np.square(Y_theta - X_theta @ theta_hat_G)
    vech_Sigma_hat_G = np.linalg.inv(Phi_theta.T @ Phi_theta) @ Phi_theta.T @ Sigma_SE

    # vech layout: strictly-upper entries (np.triu_indices order), then the diagonal.
    upper_Sigma_hat_G = np.zeros((d, d))
    upper_Sigma_hat_G[np.triu_indices(d, 1)] = vech_Sigma_hat_G[:-d]
    Sigma_hat_G = upper_Sigma_hat_G + upper_Sigma_hat_G.T
    np.fill_diagonal(Sigma_hat_G, vech_Sigma_hat_G[-d:])

    # Project onto the positive semidefinite cone.
    w, v = np.linalg.eigh(Sigma_hat_G)
    w[w < 0] = 0
    Sigma_hat_G = v @ np.diag(w) @ v.T
    return theta_hat_G, Sigma_hat_G

def inv_min_eig_Phi(X, Phi, iterations):
    """Random search for a well-conditioned square sub-design of Phi.

    Draws ``iterations`` random subsets of phi_d = d(d+1)/2 distinct arms
    (sampled without replacement). It keeps the subset whose rows of Phi
    minimise 1 / sigma_min, where sigma_min is the smallest singular value.
    Draws with sigma_min == 0 are skipped.

    Returns:
        (choice, inv_sigma_min). If no draw beats the initial value 1000,
        returns ``np.arange(phi_d)`` and 1000.
    """
    n_arms = X.shape[0]
    dim = X.shape[1]
    k = int(dim * (dim + 1) / 2)
    best_choice = np.arange(k)
    best_score = 1000
    for _ in range(iterations):
        candidate = rand.choice(n_arms, size=(1, k), replace=False)[0]
        singular_values = np.linalg.svd(Phi[candidate])[1]
        smallest = min(singular_values)
        if smallest == 0:
            continue
        score = 1 / smallest
        if score < best_score:
            best_score, best_choice = score, candidate
    return best_choice, best_score
        

def Alt_est(X, Sigma, theta, iterations, samples):
    """Estimate Sigma from repeated pulls of a well-conditioned set of arms.

    ``inv_min_eig_Phi`` picks phi_d = d(d+1)/2 arms whose covariance features
    form a well-conditioned square system. Each chosen arm is pulled
    ``samples // phi_d`` times. The squared deviations of each arm's rewards
    from that arm's own sample mean are regressed onto the arms' features to
    estimate vech(Sigma). Negative eigenvalues of the result are clipped to
    zero before returning.
    """
    d = X.shape[1]
    phi_d = int(d * (d + 1) / 2)
    Phi = make_Cov_space(X)
    chosen_arms, _ = inv_min_eig_Phi(X, Phi, iterations)
    reps = samples // phi_d
    Phi_rep = Phi[np.repeat(chosen_arms, reps)]

    # Squared deviations, arm by arm, in the same order as the rows of Phi_rep.
    deviations = []
    for arm in chosen_arms:
        x = X[arm]
        noise_sd = np.sqrt(x.T @ Sigma @ x)
        rewards = X[[arm] * reps] @ theta + np.multiply([noise_sd] * reps, rand.randn(reps))
        deviations.extend(np.square(rewards - np.mean(rewards)))

    vech_hat = np.linalg.inv(Phi_rep.T @ Phi_rep) @ Phi_rep.T @ deviations

    # vech layout: strictly-upper entries (np.triu_indices order), then the diagonal.
    upper = np.zeros((d, d))
    upper[np.triu_indices(d, 1)] = vech_hat[:-d]
    Sigma_hat = upper + upper.T
    np.fill_diagonal(Sigma_hat, vech_hat[-d:])

    # Project onto the positive semidefinite cone.
    eigvals, eigvecs = np.linalg.eigh(Sigma_hat)
    eigvals[eigvals < 0] = 0
    return eigvecs @ np.diag(eigvals) @ eigvecs.T

def unit_sphere_samp(arms, d):
    """Sample ``arms`` points uniformly on the unit sphere in R^d.

    Standard Gaussian draws are normalised to unit Euclidean length.
    Returns an (arms, d) array.
    """
    gaussian = rand.multivariate_normal(np.zeros(d), np.eye(d), arms)
    return gaussian / np.linalg.norm(gaussian, axis=1, keepdims=True)

def two_spheres_samp(arms, d, big_arms=100, shrink=10):
    """Sample arms from two concentric spheres in R^d.

    The first ``big_arms`` rows lie uniformly on the unit sphere. The
    remaining ``arms - big_arms`` rows lie uniformly on the sphere of radius
    1 / ``shrink``. Each sphere is sampled by normalising standard Gaussian
    draws, big sphere first.

    Args:
        arms: total number of rows to return.
        d: ambient dimension.
        big_arms: number of rows on the unit sphere (default 100).
        shrink: the small sphere has radius 1 / shrink (default 10).

    Returns:
        (arms, d) array.

    Raises:
        ValueError: if ``arms < big_arms``.
    """
    if arms < big_arms:
        raise ValueError(f"arms ({arms}) must be at least big_arms ({big_arms})")
    blocks = []
    for n_rows, scale in ((big_arms, 1), (arms - big_arms, shrink)):
        draws = rand.multivariate_normal([0] * d, np.diag([1] * d), n_rows)
        norms = np.linalg.norm(draws, axis=1, keepdims=True)
        blocks.append(draws / (scale * norms))
    return np.concatenate(blocks, axis=0)

def run_dim_sim(sims, arms, homogeneous, independent, threshold, iterations, samples):
    """Compare Sigma estimators across dimensions d = 10, ..., 19.

    For each dimension a ground-truth Sigma is built: diagonal, or with fixed
    off-diagonal correlations when ``independent`` is False. Then, ``sims``
    times, fresh arms are drawn on the unit sphere. Every estimator is scored
    by the mean squared error of its predicted variances x^T Sigma_hat x over
    all arms.

    Returns:
        (Sigma_hat_G, sim_matrix_mean, sim_matrix_G, sim_matrix_WE,
        sim_matrix_WE_G, sim_matrix_G_nsp); each matrix has shape (sims, 10).
        Sigma_hat_G is the two_spaces estimate from the last simulation
        (None if ``sims`` is 0).
    """
    import time

    def arm_variances(X, S):
        # x^T S x for every row of X, without forming the (arms, arms) matrix X S X^T.
        return (X @ S * X).sum(-1)

    dimensions = np.arange(10, 20)
    shape = (sims, len(dimensions))
    sim_matrix_mean = np.zeros(shape)
    sim_matrix_G = np.zeros(shape)
    sim_matrix_WE = np.zeros(shape)
    sim_matrix_WE_G = np.zeros(shape)
    sim_matrix_G_nsp = np.zeros(shape)
    Sigma_hat_G = None

    theta_vec = [1, -1] * 10

    for j, d in enumerate(dimensions):
        start = time.time()
        theta = theta_vec[:d]
        if homogeneous:
            Sigma_vec = [1.0] * d
        else:
            Sigma_vec = [1.0, 2.0] * d

        upper_d = int(d * (d + 1) / 2) - d
        cor_vec = [0.5, 0.6, 0.7] * 100
        upper_tri = cor_vec[:upper_d]
        Sigma = np.diag(Sigma_vec[:d])

        if not independent:
            Sigma = np.zeros((d, d))
            Sigma[np.triu_indices(d, 1)] = upper_tri
            Sigma = Sigma + Sigma.T
            np.fill_diagonal(Sigma, Sigma_vec[:d])
        print(Sigma)
        print(np.around(np.linalg.inv(Sigma), 2))
        for i in range(sims):
            # sample arms
            X = unit_sphere_samp(arms, d)

            # Estimators
            Sigma_hat_mean = Alt_est(X, Sigma, theta, iterations, samples)
            _, Sigma_hat_G = two_spaces(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_WE = White_Estimator(X, Sigma, theta, threshold, iterations, samples)
            # NOTE(review): White_Estimator_G is not defined in the visible part of
            # this file (only White_Estimator_theta_G / White_Estimator_Sigma_G are)
            # -- confirm it exists elsewhere.
            _, Sigma_hat_WE_G = White_Estimator_G(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_G_nsp = two_spaces_no_split(X, Sigma, theta, threshold, iterations, samples)

            # storage: MSE of predicted per-arm variances
            true_var = arm_variances(X, Sigma)
            sim_matrix_mean[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_mean) - true_var))
            sim_matrix_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_G) - true_var))
            sim_matrix_WE[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE) - true_var))
            sim_matrix_WE_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE_G) - true_var))
            sim_matrix_G_nsp[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_G_nsp) - true_var))
        print("Time")
        print(time.time() - start)
        print("Dimension")
        print(j)

    return Sigma_hat_G, sim_matrix_mean, sim_matrix_G, sim_matrix_WE, sim_matrix_WE_G, sim_matrix_G_nsp

def run_sample_multiproc(sims, arms, homogeneous, independent, threshold, iterations, dimension):
    """Compare Sigma estimators across sample budgets 1000, 6000, 11000 and 16000.

    For each budget a ground-truth Sigma of size ``dimension`` is built:
    diagonal, or with fixed off-diagonal correlations when ``independent``
    is False. Then, ``sims`` times, fresh arms are drawn on the unit sphere.
    Each estimator is scored by the mean squared error of its predicted
    variances x^T Sigma_hat x over all arms. The Alt_est column
    (``sim_matrix_mean``) is disabled and stays zero.

    Returns:
        (Sigma_hat_G, sim_matrix_mean, sim_matrix_G, sim_matrix_WE,
        sim_matrix_WE_G, sim_matrix_G_nsp); each matrix has shape (sims, 4).
        Sigma_hat_G is the two_spaces estimate from the last simulation
        (None if ``sims`` is 0).
    """
    import time

    def arm_variances(X, S):
        # x^T S x for every row of X, without forming the (arms, arms) matrix X S X^T.
        return (X @ S * X).sum(-1)

    samples_vec = [1000, 6000, 11000, 16000]
    shape = (sims, len(samples_vec))
    sim_matrix_mean = np.zeros(shape)
    sim_matrix_G = np.zeros(shape)
    sim_matrix_WE = np.zeros(shape)
    sim_matrix_WE_G = np.zeros(shape)
    sim_matrix_G_nsp = np.zeros(shape)
    d = dimension
    theta_vec = [1, -1] * 10
    Sigma_hat_G = None

    for j, samples in enumerate(samples_vec):
        start = time.time()
        theta = theta_vec[:d]
        if homogeneous:
            Sigma_vec = [1.0] * d
        else:
            Sigma_vec = [1.0, 2.0] * d

        upper_d = int(d * (d + 1) / 2) - d
        cor_vec = [0.5, -0.6, 0.7] * 100
        upper_tri = cor_vec[:upper_d]
        Sigma = np.diag(Sigma_vec[:d])

        if not independent:
            Sigma = np.zeros((d, d))
            Sigma[np.triu_indices(d, 1)] = upper_tri
            Sigma = Sigma + Sigma.T
            np.fill_diagonal(Sigma, Sigma_vec[:d])
        print(Sigma)
        print(np.around(np.linalg.inv(Sigma), 2))

        # One row of every sim_matrix_* per simulation i.
        for i in range(sims):
            # sample arms
            X = unit_sphere_samp(arms, d)

            # Estimators
            _, Sigma_hat_G = two_spaces(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_WE = White_Estimator(X, Sigma, theta, threshold, iterations, samples)
            # NOTE(review): White_Estimator_G is not defined in the visible part of
            # this file (only White_Estimator_theta_G / White_Estimator_Sigma_G are)
            # -- confirm it exists elsewhere.
            _, Sigma_hat_WE_G = White_Estimator_G(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_G_nsp = two_spaces_no_split(X, Sigma, theta, threshold, iterations, samples)

            # storage: MSE of predicted per-arm variances
            true_var = arm_variances(X, Sigma)
            sim_matrix_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_G) - true_var))
            sim_matrix_WE[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE) - true_var))
            sim_matrix_WE_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE_G) - true_var))
            sim_matrix_G_nsp[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_G_nsp) - true_var))

        print("Time")
        print(time.time() - start)
        print("Dimension")
        print(j)

    return Sigma_hat_G, sim_matrix_mean, sim_matrix_G, sim_matrix_WE, sim_matrix_WE_G, sim_matrix_G_nsp


def run_sample_sim(sims, arms, homogeneous, independent, threshold, iterations, dimension, two_spheres):
    """Compare Sigma estimators across sample budgets for a diagonal Sigma.

    Sigma is diagonal with entries 1.0 (``homogeneous``) or alternating
    0.1, 1 (otherwise), truncated to ``dimension`` so that it matches the
    arms. ``independent`` is accepted for signature parity and is unused.
    Arms are drawn with ``two_spheres_samp`` when ``two_spheres`` is true,
    otherwise on the unit sphere. Each estimator is scored by the mean
    squared error of its predicted variances x^T Sigma_hat x over all arms.

    Returns:
        (Sigma_hat_G, sim_matrix_G, sim_matrix_WE, sim_matrix_G_nsp,
        sim_matrix_WE_theta_G, sim_matrix_WE_Sigma_G); each matrix has shape
        (sims, len(samples_vec)). Sigma_hat_G is the two_spaces estimate from
        the last simulation (None if nothing ran).
    """
    import time

    def arm_variances(X, S):
        # x^T S x for every row of X, without forming the (arms, arms) matrix X S X^T.
        return (X @ S * X).sum(-1)

    # NOTE(review): list repetition yields 100000 budgets cycling 0..99, not
    # 0, 1000, ..., 99000; a budget of 0 leaves the least-squares systems
    # empty/singular -- confirm intent.
    samples_vec = list(range(100)) * 1000
    shape = (sims, len(samples_vec))
    sim_matrix_G = np.zeros(shape)
    sim_matrix_WE = np.zeros(shape)
    sim_matrix_WE_theta_G = np.zeros(shape)
    sim_matrix_WE_Sigma_G = np.zeros(shape)
    sim_matrix_G_nsp = np.zeros(shape)
    d = dimension
    theta_vec = [1, -1] * 100
    theta = theta_vec[:d]

    if homogeneous:
        Sigma_vec = [1.0] * 100
    else:
        Sigma_vec = [0.1, 1] * 100

    # Truncate to d so Sigma is d x d like the arms (as in the other run_* drivers).
    Sigma = np.diag(Sigma_vec[:d])
    Sigma_hat_G = None

    for j, samples in enumerate(samples_vec):
        start = time.time()
        for i in range(sims):
            # sample arms
            if two_spheres:
                X = two_spheres_samp(arms, d)
            else:
                X = unit_sphere_samp(arms, d)

            # Estimators
            _, Sigma_hat_G = two_spaces(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_WE = White_Estimator(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_WE_theta_G = White_Estimator_theta_G(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_WE_Sigma_G = White_Estimator_Sigma_G(X, Sigma, theta, threshold, iterations, samples)
            _, Sigma_hat_G_nsp = two_spaces_no_split(X, Sigma, theta, threshold, iterations, samples)

            # storage: MSE of predicted per-arm variances
            true_var = arm_variances(X, Sigma)
            sim_matrix_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_G) - true_var))
            sim_matrix_WE[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE) - true_var))
            sim_matrix_WE_theta_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE_theta_G) - true_var))
            sim_matrix_WE_Sigma_G[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_WE_Sigma_G) - true_var))
            sim_matrix_G_nsp[i, j] = np.mean(np.square(arm_variances(X, Sigma_hat_G_nsp) - true_var))

        print("Time")
        print(time.time() - start)
        print("Sample")
        print(j)

    return Sigma_hat_G, sim_matrix_G, sim_matrix_WE, sim_matrix_G_nsp, sim_matrix_WE_theta_G, sim_matrix_WE_Sigma_G

def _variance_comp_worker(sim_index, arms, homogeneous, independent, threshold, iterations, dimension):
    """Run a single simulation of ``run_sample_multiproc`` inside a pool worker.

    NumPy's global RNG is reseeded with ``sim_index`` first. Forked workers
    otherwise inherit identical RNG state and would replay the same draws.
    """
    rand.seed(sim_index)
    return run_sample_multiproc(1, arms, homogeneous, independent, threshold, iterations, dimension)


def parallelize_variance_comp(sims, arms, homogeneous, independent, threshold, iterations, dimension):
    """Run ``sims`` independent simulations of ``run_sample_multiproc`` in parallel.

    Runs one pool task per simulation, each with ``sims=1`` and the caller's
    parameters. The pool uses one process per CPU core.

    Returns:
        pandas DataFrame indexed by (sim, estimator), with one column per
        sample budget of ``run_sample_multiproc`` (in that function's order).
        Estimator labels follow its returned matrices: "mean", "G", "WE",
        "WE_G", "G_nsp". Returns an empty DataFrame when ``sims`` is 0.
    """
    import functools
    import multiprocessing

    import pandas as pd

    # Bind by keyword so each mapped index lands in ``sim_index`` only.
    worker = functools.partial(
        _variance_comp_worker,
        arms=arms,
        homogeneous=homogeneous,
        independent=independent,
        threshold=threshold,
        iterations=iterations,
        dimension=dimension,
    )
    # The context manager terminates the worker processes once map returns.
    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        results = pool.map(worker, range(sims))

    estimators = ["mean", "G", "WE", "WE_G", "G_nsp"]
    # result[0] is Sigma_hat_G; result[1:] are the (1, n_budgets) error matrices.
    frames = [pd.DataFrame(np.vstack(result[1:]), index=estimators) for result in results]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, keys=list(range(sims)), names=["sim", "estimator"])






def White_Estimator_theta_G(X, Sigma, theta, threshold, iterations, samples):
    """White-style estimator with arms drawn from the G-optimal design on X.

    ``samples`` arms are drawn i.i.d. with probabilities proportional to the
    G-optimal weights on X. The same pulls fit theta by least squares and
    Sigma by regressing squared residuals onto the covariance features. The
    Sigma estimate is projected onto the PSD cone by clipping negative
    eigenvalues to zero.

    Returns:
        (theta_hat_G, Sigma_hat_G).
    """
    total_arms = X.shape[0]
    d = X.shape[1]
    Phi = make_Cov_space(X)
    _, _, lambda_vec = G_optimal(X, threshold, iterations)
    choice = rand.choice(total_arms, size=(1, samples), p=lambda_vec / np.sum(lambda_vec))[0]
    Phi_theta = Phi[choice]
    X_theta = X[choice]

    # Pull arms: noise variance x^T Sigma x per pull, computed row-wise rather
    # than via np.diag(X_theta @ Sigma @ X_theta.T), which would allocate a
    # (samples, samples) matrix.
    X_theta_noise = (X_theta @ Sigma * X_theta).sum(-1)
    Y_theta = X_theta @ theta + np.multiply(np.sqrt(X_theta_noise), rand.randn(samples))
    theta_hat_G = np.linalg.inv(X_theta.T @ X_theta) @ X_theta.T @ Y_theta

    # Regress squared residuals onto phi(x) to estimate vech(Sigma).
    Sigma_SE = np.square(Y_theta - X_theta @ theta_hat_G)
    vech_Sigma_hat_G = np.linalg.inv(Phi_theta.T @ Phi_theta) @ Phi_theta.T @ Sigma_SE

    # vech layout: strictly-upper entries (np.triu_indices order), then the diagonal.
    upper_Sigma_hat_G = np.zeros((d, d))
    upper_Sigma_hat_G[np.triu_indices(d, 1)] = vech_Sigma_hat_G[:-d]
    Sigma_hat_G = upper_Sigma_hat_G + upper_Sigma_hat_G.T
    np.fill_diagonal(Sigma_hat_G, vech_Sigma_hat_G[-d:])

    # Project onto the positive semidefinite cone.
    w, v = np.linalg.eigh(Sigma_hat_G)
    w[w < 0] = 0
    Sigma_hat_G = v @ np.diag(w) @ v.T
    return theta_hat_G, Sigma_hat_G

def White_Estimator_Sigma_G(X, Sigma, theta, threshold, iterations, samples):
    """White-style estimator with arms drawn from the G-optimal design on Phi.

    ``samples`` arms are drawn i.i.d. with probabilities proportional to the
    G-optimal weights on the covariance features Phi. The same pulls fit
    theta by least squares and Sigma by regressing squared residuals onto
    Phi. The Sigma estimate is projected onto the PSD cone by clipping
    negative eigenvalues to zero.

    Returns:
        (theta_hat_G, Sigma_hat_G).
    """
    total_arms = X.shape[0]
    d = X.shape[1]
    Phi = make_Cov_space(X)
    _, _, lambda_vec = G_optimal(Phi, threshold, iterations)
    choice = rand.choice(total_arms, size=(1, samples), p=lambda_vec / np.sum(lambda_vec))[0]
    Phi_theta = Phi[choice]
    X_theta = X[choice]

    # Pull arms: noise variance x^T Sigma x per pull, computed row-wise rather
    # than via np.diag(X_theta @ Sigma @ X_theta.T), which would allocate a
    # (samples, samples) matrix.
    X_theta_noise = (X_theta @ Sigma * X_theta).sum(-1)
    Y_theta = X_theta @ theta + np.multiply(np.sqrt(X_theta_noise), rand.randn(samples))
    theta_hat_G = np.linalg.inv(X_theta.T @ X_theta) @ X_theta.T @ Y_theta

    # Regress squared residuals onto phi(x) to estimate vech(Sigma).
    Sigma_SE = np.square(Y_theta - X_theta @ theta_hat_G)
    vech_Sigma_hat_G = np.linalg.inv(Phi_theta.T @ Phi_theta) @ Phi_theta.T @ Sigma_SE

    # vech layout: strictly-upper entries (np.triu_indices order), then the diagonal.
    upper_Sigma_hat_G = np.zeros((d, d))
    upper_Sigma_hat_G[np.triu_indices(d, 1)] = vech_Sigma_hat_G[:-d]
    Sigma_hat_G = upper_Sigma_hat_G + upper_Sigma_hat_G.T
    np.fill_diagonal(Sigma_hat_G, vech_Sigma_hat_G[-d:])

    # Project onto the positive semidefinite cone.
    w, v = np.linalg.eigh(Sigma_hat_G)
    w[w < 0] = 0
    Sigma_hat_G = v @ np.diag(w) @ v.T
    return theta_hat_G, Sigma_hat_G